from __future__ import absolute_import

from numbers import Number

from plotly import exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
from plotly.tools import make_subplots

# Optional third-party dependencies, resolved through plotly's
# optional_imports helper instead of plain import statements.
# NOTE(review): get_module presumably returns a falsy value (None) when the
# package is not installed -- create_violin tests `pd` for truthiness before
# using it -- confirm against plotly.optional_imports.
pd = optional_imports.get_module('pandas')
np = optional_imports.get_module('numpy')
scipy_stats = optional_imports.get_module('scipy.stats')


def _percentile(x, q, method):
    """
    Return the q-th percentile of x using the named estimation method.

    Newer NumPy releases spell the keyword ``method``; older ones only
    accept ``interpolation``. Try the new spelling first and fall back so
    the module works on both.
    """
    try:
        return np.percentile(x, q, method=method)
    except TypeError:
        # NumPy too old to know the ``method`` keyword
        return np.percentile(x, q, interpolation=method)


def calc_stats(data):
    """
    Calculate statistics for use in violin plot.

    :param (list|array) data: numeric values; converted to a float array.
    :return (dict): a dictionary with the keys
        'min', 'max': smallest and largest value in data,
        'q1', 'q2', 'q3': 25th ('lower'), 50th ('linear') and 75th
            ('higher') percentiles,
        'd1', 'd2': whisker ends, i.e. the smallest value that is
            >= q1 - 1.5 * IQR and the largest value that is
            <= q3 + 1.5 * IQR.
    """
    # the builtin float is what the removed np.float alias pointed to
    x = np.asarray(data, float)
    vals_min = np.min(x)
    vals_max = np.max(x)
    q2 = _percentile(x, 50, 'linear')
    q1 = _percentile(x, 25, 'lower')
    q3 = _percentile(x, 75, 'higher')
    iqr = q3 - q1
    whisker_dist = 1.5 * iqr

    # in order to prevent drawing whiskers outside the interval
    # of data one defines the whisker positions as:
    d1 = np.min(x[x >= (q1 - whisker_dist)])
    d2 = np.max(x[x <= (q3 + whisker_dist)])
    return {
        'min': vals_min,
        'max': vals_max,
        'q1': q1,
        'q2': q2,
        'q3': q3,
        'd1': d1,
        'd2': d2
    }


def make_half_violin(x, y, fillcolor='#1f77b4', linecolor='rgb(0, 0, 0)'):
    """
    Build one side of a violin as a filled, spline-shaped line trace.

    :param x: horizontal coordinates of the outline; each one is shown as
        the 'pdf(y)' part of the hover label.
    :param y: vertical coordinates, read at the same indices as x.
    :param (str) fillcolor: color used to fill the half violin.
    :param (str) linecolor: color of the outline.
    :return: a graph_objs.Scatter trace with one hover label per point.
    """
    hover_template = '(pdf(y), y)=({:0.2f}, {:0.2f})'
    hover_text = []
    for idx in range(len(x)):
        hover_text.append(hover_template.format(x[idx], y[idx]))

    outline = graph_objs.Line(width=0.5, color=linecolor, shape='spline')

    return graph_objs.Scatter(
        x=x,
        y=y,
        mode='lines',
        name='',
        text=hover_text,
        fill='tonextx',
        fillcolor=fillcolor,
        line=outline,
        hoverinfo='text',
        opacity=0.5
    )


def make_violin_rugplot(vals, pdf_max, distance, color='#1f77b4'):
    """
    Build the rug (one tick marker per observation) drawn beside a violin.

    :param vals: the observations; used as the y coordinates.
    :param pdf_max: largest density value of the violin.
    :param distance: extra gap between the violin and the rug.
    :param (str) color: marker color.
    :return: a graph_objs.Scatter marker trace whose points all share the
        x coordinate -pdf_max - distance.
    """
    rug_x = -pdf_max - distance
    tick_marker = graph_objs.Marker(color=color, symbol='line-ew-open')

    return graph_objs.Scatter(
        x=[rug_x] * len(vals),
        y=vals,
        mode='markers',
        marker=tick_marker,
        name='',
        hoverinfo='y',
        showlegend=False
    )


def make_non_outlier_interval(d1, d2):
    """
    Build the thin vertical line joining the two whisker ends of a violin.

    :param d1: y coordinate of the lower end of the line.
    :param d2: y coordinate of the upper end of the line.
    :return: a graph_objs.Scatter line trace drawn at x = 0.
    """
    whisker_line = graph_objs.Line(width=1.5, color='rgb(0,0,0)')
    return graph_objs.Scatter(
        x=[0, 0],
        y=[d1, d2],
        mode='lines',
        name='',
        line=whisker_line
    )


def make_quartiles(q1, q3):
    """
    Build the thick vertical bar spanning the lower to the upper quartile.

    :param q1: lower quartile (bottom of the bar).
    :param q3: upper quartile (top of the bar).
    :return: a graph_objs.Scatter line trace at x = 0 whose hover text
        names both quartiles with two decimals.
    """
    hover_text = [
        'lower-quartile: {:0.2f}'.format(q1),
        'upper-quartile: {:0.2f}'.format(q3),
    ]
    box_line = graph_objs.Line(width=4, color='rgb(0,0,0)')
    return graph_objs.Scatter(
        x=[0, 0],
        y=[q1, q3],
        mode='lines',
        text=hover_text,
        hoverinfo='text',
        line=box_line
    )


def make_median(q2):
    """
    Build the white square marker that shows the median of a violin.

    :param q2: the median value (y coordinate of the marker).
    :return: a graph_objs.Scatter marker trace at x = 0 whose hover text
        is the median formatted with two decimals.
    """
    median_marker = dict(symbol='square', color='rgb(255,255,255)')
    hover_text = ['median: {:0.2f}'.format(q2)]
    return graph_objs.Scatter(
        x=[0],
        y=[q2],
        mode='markers',
        marker=median_marker,
        text=hover_text,
        hoverinfo='text'
    )


def make_XAxis(xaxis_title, xaxis_range):
    """
    Build the x-axis of a violin plot: titled, with a fixed range, and with
    grid, zero line, axis line, ticks and tick labels all switched off.

    :param xaxis_title: title shown for the axis.
    :param xaxis_range: [low, high] range of the axis.
    :return: a graph_objs.XAxis object.
    """
    axis_settings = dict(
        title=xaxis_title,
        range=xaxis_range,
        ticks='',
        showticklabels=False,
        showgrid=False,
        showline=False,
        zeroline=False,
        mirror=False,
    )
    return graph_objs.XAxis(**axis_settings)


def make_YAxis(yaxis_title):
    """
    Build the y-axis of a violin plot: auto-ranged, with axis line and tick
    labels shown, and with grid, zero line and mirroring switched off.

    :param yaxis_title: title shown for the axis.
    :return: a graph_objs.YAxis object.
    """
    axis_settings = dict(
        title=yaxis_title,
        autorange=True,
        ticklen=4,
        showticklabels=True,
        showline=True,
        showgrid=False,
        zeroline=False,
        mirror=False,
    )
    return graph_objs.YAxis(**axis_settings)


def violinplot(vals, fillcolor='#1f77b4', rugplot=True):
    """
    Build the traces for a single violin.

    Refer to FigureFactory.create_violin() for the user-facing docstring.

    :param (list|array) vals: numeric observations of one group.
    :param (str) fillcolor: fill color of the violin; also used for the rug.
    :param (bool) rugplot: whether a rug trace is appended.
    :return (tuple): (plot_data, plot_xrange) where plot_data is the list of
        traces (two half violins, whisker line, quartile bar, median marker
        and optionally the rug) and plot_xrange is the [low, high] x range
        that fits them.
    """
    # the builtin float is what the removed np.float alias pointed to
    vals = np.asarray(vals, float)

    # summary statistics, computed once for all seven values
    stats = calc_stats(vals)
    vals_min = stats['min']
    vals_max = stats['max']
    q1 = stats['q1']
    q2 = stats['q2']
    q3 = stats['q3']
    d1 = stats['d1']
    d2 = stats['d2']

    # kernel density estimation of pdf
    pdf = scipy_stats.gaussian_kde(vals)
    # grid over the data interval
    xx = np.linspace(vals_min, vals_max, 100)
    # evaluate the pdf at the grid xx
    yy = pdf(xx)
    max_pdf = np.max(yy)
    # distance from the violin plot to rugplot
    distance = (2.0 * max_pdf)/10 if rugplot else 0
    # range for x values in the plot (0.1 of padding on either side)
    plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
    plot_data = [make_half_violin(-yy, xx, fillcolor=fillcolor),
                 make_half_violin(yy, xx, fillcolor=fillcolor),
                 make_non_outlier_interval(d1, d2),
                 make_quartiles(q1, q3),
                 make_median(q2)]
    if rugplot:
        plot_data.append(make_violin_rugplot(vals, max_pdf, distance=distance,
                                             color=fillcolor))
    return plot_data, plot_xrange


def violin_no_colorscale(data, data_header, group_header, colors,
                         use_colorscale, group_stats, rugplot, sort,
                         height, width, title):
    """
    Refer to FigureFactory.create_violin() for docstring.

    Returns fig for violin plot without colorscale: one subplot per group,
    colored by cycling through the `colors` sequence.

    `use_colorscale` and `group_stats` are accepted for signature parity
    with the other violin builders but are not used here.
    """

    # collect all group names in order of first appearance
    group_name = []
    for name in data[group_header]:
        if name not in group_name:
            group_name.append(name)
    if sort:
        group_name.sort()

    # group by the bare column label (not a one-element list) so that
    # get_group() can be called with the plain group value
    gb = data.groupby(group_header)
    L = len(group_name)

    fig = make_subplots(rows=1, cols=L,
                        shared_yaxes=True,
                        horizontal_spacing=0.025,
                        print_grid=False)

    for k, gr in enumerate(group_name):
        vals = np.asarray(gb.get_group(gr)[data_header], float)
        # wrap around when there are more groups than colors
        fillcolor = colors[k % len(colors)]
        plot_data, plot_xrange = violinplot(vals,
                                            fillcolor=fillcolor,
                                            rugplot=rugplot)

        for item in plot_data:
            fig.append_trace(item, 1, k + 1)

        # add violin plot labels
        fig['layout'].update(
            {'xaxis{}'.format(k + 1): make_XAxis(gr, plot_xrange)}
        )

    # set the sharey axis style
    fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
    fig['layout'].update(
        title=title,
        showlegend=False,
        hovermode='closest',
        autosize=False,
        height=height,
        width=width
    )

    return fig


def violin_colorscale(data, data_header, group_header, colors, use_colorscale,
                      group_stats, rugplot, sort, height, width,
                      title):
    """
    Refer to FigureFactory.create_violin() for docstring.

    Returns fig for violin plot with colorscale: one subplot per group,
    each filled with a color interpolated between colors[0] and colors[1]
    according to the group's value in `group_stats`, plus a colorbar.

    `use_colorscale` is accepted for signature parity with the other violin
    builders but is not used here.

    :raises: (PlotlyError) if a group found in the group_header column is
        not a key of group_stats.
    """

    # collect all group names in order of first appearance
    group_name = []
    for name in data[group_header]:
        if name not in group_name:
            group_name.append(name)
    if sort:
        group_name.sort()

    # make sure all group names are keys in group_stats
    for group in group_name:
        if group not in group_stats:
            raise exceptions.PlotlyError("All values/groups in the index "
                                         "column must be represented "
                                         "as a key in group_stats.")

    # group by the bare column label (not a one-element list) so that
    # get_group() can be called with the plain group value
    gb = data.groupby(group_header)
    L = len(group_name)

    fig = make_subplots(rows=1, cols=L,
                        shared_yaxes=True,
                        horizontal_spacing=0.025,
                        print_grid=False)

    # prepare low and high color for colorscale
    lowcolor = utils.color_parser(colors[0], utils.unlabel_rgb)
    highcolor = utils.color_parser(colors[1], utils.unlabel_rgb)

    # find min and max values in group_stats
    group_stats_values = list(group_stats.values())
    max_value = max(group_stats_values)
    min_value = min(group_stats_values)
    value_span = max_value - min_value

    for k, gr in enumerate(group_name):
        vals = np.asarray(gb.get_group(gr)[data_header], float)

        # find intermediate color from colorscale
        if value_span:
            # float() keeps this a true division for integer stats
            intermed = (group_stats[gr] - min_value) / float(value_span)
        else:
            # all stats are equal (e.g. a single group): there is no range
            # to interpolate over, so use the midpoint of the colorscale
            # instead of dividing by zero
            intermed = 0.5
        intermed_color = utils.find_intermediate_color(
            lowcolor, highcolor, intermed
        )

        plot_data, plot_xrange = violinplot(
            vals,
            fillcolor='rgb{}'.format(intermed_color),
            rugplot=rugplot
        )

        for item in plot_data:
            fig.append_trace(item, 1, k + 1)
        fig['layout'].update(
            {'xaxis{}'.format(k + 1): make_XAxis(gr, plot_xrange)}
        )

    # add colorbar to plot: a tiny dummy marker trace in the last subplot
    # whose only job is to carry the colorscale
    trace_dummy = graph_objs.Scatter(
        x=[0],
        y=[0],
        mode='markers',
        marker=dict(
            size=2,
            cmin=min_value,
            cmax=max_value,
            colorscale=[[0, colors[0]],
                        [1, colors[1]]],
            showscale=True),
        showlegend=False,
    )
    fig.append_trace(trace_dummy, 1, L)

    # set the sharey axis style
    fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
    fig['layout'].update(
        title=title,
        showlegend=False,
        hovermode='closest',
        autosize=False,
        height=height,
        width=width
    )

    return fig


def violin_dict(data, data_header, group_header, colors, use_colorscale,
                group_stats, rugplot, sort, height, width, title):
    """
    Refer to FigureFactory.create_violin() for docstring.

    Returns fig for violin plot whose colors are given as a dictionary:
    one subplot per group, each filled with colors[<group name>].

    `use_colorscale` and `group_stats` are accepted for signature parity
    with the other violin builders but are not used here.

    :raises: (PlotlyError) if a group found in the group_header column is
        not a key of colors.
    """

    # collect all group names in order of first appearance
    group_name = []
    for name in data[group_header]:
        if name not in group_name:
            group_name.append(name)

    if sort:
        group_name.sort()

    # check if all group names appear in colors dict
    for group in group_name:
        if group not in colors:
            raise exceptions.PlotlyError("If colors is a dictionary, all "
                                         "the group names must appear as "
                                         "keys in colors.")

    # group by the bare column label (not a one-element list) so that
    # get_group() can be called with the plain group value
    gb = data.groupby(group_header)
    L = len(group_name)

    fig = make_subplots(rows=1, cols=L,
                        shared_yaxes=True,
                        horizontal_spacing=0.025,
                        print_grid=False)

    for k, gr in enumerate(group_name):
        vals = np.asarray(gb.get_group(gr)[data_header], float)
        plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr],
                                            rugplot=rugplot)

        for item in plot_data:
            fig.append_trace(item, 1, k + 1)

        # add violin plot labels
        fig['layout'].update(
            {'xaxis{}'.format(k + 1): make_XAxis(gr, plot_xrange)}
        )

    # set the sharey axis style
    fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
    fig['layout'].update(
        title=title,
        showlegend=False,
        hovermode='closest',
        autosize=False,
        height=height,
        width=width
    )

    return fig


def create_violin(data, data_header=None, group_header=None, colors=None,
                  use_colorscale=False, group_stats=None, rugplot=True,
                  sort=False, height=450, width=600,
                  title='Violin and Rug Plot'):
    """
    Returns figure for a violin plot

    :param (list|array) data: accepts either a list of numerical values,
        a list of dictionaries all with identical keys and at least one
        column of numeric values, or a pandas dataframe with at least one
        column of numbers.
    :param (str) data_header: the header of the data column to be used
        from an inputted pandas dataframe. Not applicable if 'data' is
        a list of numeric values.
    :param (str) group_header: applicable if grouping data by a variable.
        'group_header' must be set to the name of the grouping variable.
    :param (str|tuple|list|dict) colors: either a plotly scale name,
        an rgb or hex color, a color tuple, a list of colors or a
        dictionary. An rgb color is of the form 'rgb(x, y, z)' where
        x, y and z belong to the interval [0, 255] and a color tuple is a
        tuple of the form (a, b, c) where a, b and c belong to [0, 1].
        If colors is a list, it must contain valid color types as its
        members.
    :param (bool) use_colorscale: only applicable if grouping by another
        variable. Will implement a colorscale based on the first 2 colors
        of param colors. This means colors must be a list with at least 2
        colors in it (Plotly colorscales are accepted since they map to a
        list of two rgb colors). Default = False
    :param (dict) group_stats: a dictionary where each key is a unique
        value from the group_header column in data. Each value must be a
        number and will be used to color the violin plots if a colorscale
        is being used.
    :param (bool) rugplot: determines if a rugplot is drawn on the violin
        plot. Default = True
    :param (bool) sort: determines if violins are sorted
        alphabetically (True) or by input order (False). Default = False
    :param (float) height: the height of the violin plot.
    :param (float) width: the width of the violin plot.
    :param (str) title: the title of the violin plot.

    :raises: (PlotlyError) if data is an empty list or a list with
        non-numeric members, if data_header is missing for a dataframe,
        if group_header is given but data is not a pandas DataFrame, or if
        the colors / group_stats arguments do not fit the colorscale mode.

    Example 1: Single Violin Plot
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_violin
    from plotly.graph_objs import graph_objs

    import numpy as np
    from scipy import stats

    # create list of random values
    data_list = np.random.randn(100).tolist()

    # create violin fig
    fig = create_violin(data_list, colors='#604d9e')

    # plot
    py.iplot(fig, filename='Violin Plot')
    ```

    Example 2: Multiple Violin Plots with Qualitative Coloring
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_violin
    from plotly.graph_objs import graph_objs

    import numpy as np
    import pandas as pd
    from scipy import stats

    # create dataframe
    np.random.seed(619517)
    Nr=250
    y = np.random.randn(Nr)
    gr = np.random.choice(list("ABCDE"), Nr)
    norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]

    for i, letter in enumerate("ABCDE"):
        y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
    df = pd.DataFrame(dict(Score=y, Group=gr))

    # create violin fig
    fig = create_violin(df, data_header='Score', group_header='Group',
                        sort=True, height=600, width=1000)

    # plot
    py.iplot(fig, filename='Violin Plot with Coloring')
    ```

    Example 3: Violin Plots with Colorscale
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_violin
    from plotly.graph_objs import graph_objs

    import numpy as np
    import pandas as pd
    from scipy import stats

    # create dataframe
    np.random.seed(619517)
    Nr=250
    y = np.random.randn(Nr)
    gr = np.random.choice(list("ABCDE"), Nr)
    norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]

    for i, letter in enumerate("ABCDE"):
        y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
    df = pd.DataFrame(dict(Score=y, Group=gr))

    # define header params
    data_header = 'Score'
    group_header = 'Group'

    # make groupby object with pandas
    group_stats = {}
    groupby_data = df.groupby([group_header])

    for group in "ABCDE":
        data_from_group = groupby_data.get_group(group)[data_header]
        # take a stat of the grouped data
        stat = np.median(data_from_group)
        # add to dictionary
        group_stats[group] = stat

    # create violin fig
    fig = create_violin(df, data_header='Score', group_header='Group',
                        height=600, width=1000, use_colorscale=True,
                        group_stats=group_stats)

    # plot
    py.iplot(fig, filename='Violin Plot with Colorscale')
    ```
    """

    # Validate colors
    if isinstance(colors, dict):
        valid_colors = utils.validate_colors_dict(colors, 'rgb')
    else:
        valid_colors = utils.validate_colors(colors, 'rgb')

    # validate data and choose plot type
    if group_header is None:
        # ungrouped: a single violin built from a list or dataframe column
        if isinstance(data, list):
            if not data:
                raise exceptions.PlotlyError("If data is a list, it must be "
                                             "nonempty and contain either "
                                             "numbers or dictionaries.")

            if not all(isinstance(element, Number) for element in data):
                raise exceptions.PlotlyError("If data is a list, it must "
                                             "contain only numbers.")

        if pd and isinstance(data, pd.core.frame.DataFrame):
            if data_header is None:
                raise exceptions.PlotlyError("data_header must be the "
                                             "column name with the "
                                             "desired numeric data for "
                                             "the violin plot.")

            data = data[data_header].values.tolist()

        # call the plotting functions
        plot_data, plot_xrange = violinplot(data, fillcolor=valid_colors[0],
                                            rugplot=rugplot)

        layout = graph_objs.Layout(
            title=title,
            autosize=False,
            font=graph_objs.Font(size=11),
            height=height,
            showlegend=False,
            width=width,
            xaxis=make_XAxis('', plot_xrange),
            yaxis=make_YAxis(''),
            hovermode='closest'
        )
        # the single-violin figure hides the y axis line, ticks and labels
        layout['yaxis'].update(dict(showline=False,
                                    showticklabels=False,
                                    ticks=''))

        return graph_objs.Figure(data=graph_objs.Data(plot_data),
                                 layout=layout)

    # grouped: one violin per distinct value of the group_header column.
    # `pd` is checked first so that a missing pandas produces the intended
    # PlotlyError instead of an AttributeError on `pd.core`.
    if not (pd and isinstance(data, pd.core.frame.DataFrame)):
        raise exceptions.PlotlyError("Error. You must use a pandas "
                                     "DataFrame if you are using a "
                                     "group header.")

    if data_header is None:
        raise exceptions.PlotlyError("data_header must be the column "
                                     "name with the desired numeric "
                                     "data for the violin plot.")

    # truthiness (not identity with False) so that 0 or a numpy boolean
    # also select the non-colorscale path
    if not use_colorscale:
        if isinstance(valid_colors, dict):
            # the dict's keys are checked against the groups in violin_dict
            return violin_dict(
                data, data_header, group_header, valid_colors,
                use_colorscale, group_stats, rugplot, sort,
                height, width, title
            )
        return violin_no_colorscale(
            data, data_header, group_header, valid_colors,
            use_colorscale, group_stats, rugplot, sort,
            height, width, title
        )

    if isinstance(valid_colors, dict):
        raise exceptions.PlotlyError("The colors param cannot be "
                                     "a dictionary if you are "
                                     "using a colorscale.")

    if len(valid_colors) < 2:
        raise exceptions.PlotlyError("colors must be a list with "
                                     "at least 2 colors. A "
                                     "Plotly scale is allowed.")

    if not isinstance(group_stats, dict):
        raise exceptions.PlotlyError("Your group_stats param "
                                     "must be a dictionary.")

    return violin_colorscale(
        data, data_header, group_header, valid_colors,
        use_colorscale, group_stats, rugplot, sort, height,
        width, title
    )
